{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import json\n",
    "import random\n",
    "import torch\n",
    "import numpy as np\n",
    "from transformers import AutoTokenizer, AutoConfig, DebertaV2ForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = AutoConfig.from_pretrained('mesolitica/malaysian-debertav2-base')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "config.problem_type = \"single_label_classification\"\n",
    "config.label2id = {'contradiction': 0, 'entailment': 1}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at mesolitica/malaysian-debertav2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = DebertaV2ForSequenceClassification.from_pretrained('mesolitica/malaysian-debertav2-base', config = config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('mesolitica/malaysian-debertav2-base')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainable_parameters = [param for param in model.parameters() if param.requires_grad]\n",
    "trainer = torch.optim.AdamW(trainable_parameters, lr = 2e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X, train_Y = [], []\n",
    "with open('shuffled-train.json') as fopen:\n",
    "    for l in fopen:\n",
    "        l = json.loads(l)\n",
    "        train_X.append(l['src'])\n",
    "        train_Y.append(l['label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16037"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_X, test_Y = [], []\n",
    "with open('shuffled-test.json') as fopen:\n",
    "    for l in fopen:\n",
    "        l = json.loads(l)\n",
    "        test_X.append(l['src'])\n",
    "        test_Y.append(l['label'])\n",
    "        \n",
    "len(test_X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 8\n",
    "epoch = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████| 136696/136696 [3:39:06<00:00, 10.40it/s, loss=0.0108]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 0.3771513450757229, dev_predicted: 0.8642\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████| 136696/136696 [3:39:28<00:00, 10.38it/s, loss=0.0225]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, loss: 0.32527882952862147, dev_predicted: 0.8572\n"
     ]
    }
   ],
   "source": [
    "best_dev_acc = -np.inf\n",
    "patient = 1\n",
    "current_patient = 0\n",
    "\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm(range(0, len(train_X), batch_size))\n",
    "    losses = []\n",
    "    for i in pbar:\n",
    "        trainer.zero_grad()\n",
    "        x = train_X[i: i + batch_size]\n",
    "        y = np.array(train_Y[i: i + batch_size])\n",
    "        \n",
    "        padded = tokenizer(x, truncation = True, padding = True, return_tensors = 'pt', max_length = 1024)\n",
    "        padded['labels'] = torch.from_numpy(y)\n",
    "        for k in padded.keys():\n",
    "            padded[k] = padded[k].cuda()\n",
    "        \n",
    "        padded.pop('token_type_ids', None)\n",
    "            \n",
    "        loss, pred = model(**padded, return_dict = False)\n",
    "        loss.backward()\n",
    "        \n",
    "        grad_norm = torch.nn.utils.clip_grad_norm_(trainable_parameters, 1.0)\n",
    "        trainer.step()\n",
    "        losses.append(float(loss))\n",
    "        pbar.set_postfix(loss = float(loss))\n",
    "        \n",
    "        \n",
    "    dev_predicted = []\n",
    "    for i in range(0, len(test_X[:10000]), batch_size):\n",
    "        x = test_X[i: i + batch_size]\n",
    "        y = np.array(test_Y[i: i + batch_size])\n",
    "        padded = tokenizer(x, truncation = True, padding = True, return_tensors = 'pt', max_length = 1024)\n",
    "        padded['labels'] = torch.from_numpy(y)\n",
    "        for k in padded.keys():\n",
    "            padded[k] = padded[k].cuda()\n",
    "            \n",
    "        padded.pop('token_type_ids', None)\n",
    "        \n",
    "        loss, pred = model(**padded, return_dict = False)\n",
    "        dev_predicted.append((pred.argmax(axis = 1).detach().cpu().numpy() == y).mean())\n",
    "        \n",
    "    dev_predicted = np.mean(dev_predicted)\n",
    "    \n",
    "    print(f'epoch: {e}, loss: {np.mean(losses)}, dev_predicted: {dev_predicted}')\n",
    "    \n",
    "    if dev_predicted >= best_dev_acc:\n",
    "        best_dev_acc = dev_predicted\n",
    "        current_patient = 0\n",
    "        model.save_pretrained('small')\n",
    "    else:\n",
    "        current_patient += 1\n",
    "    \n",
    "    if current_patient >= patient:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 2005/2005 [00:28<00:00, 70.29it/s]\n"
     ]
    }
   ],
   "source": [
    "real_Y = []\n",
    "for i in tqdm(range(0, len(test_X), batch_size)):\n",
    "    x = test_X[i: i + batch_size]\n",
    "    y = np.array(test_Y[i: i + batch_size])\n",
    "    padded = tokenizer(x, padding = 'longest', return_tensors = 'pt')\n",
    "    padded['labels'] = torch.from_numpy(y)\n",
    "    for k in padded.keys():\n",
    "        padded[k] = padded[k].cuda()\n",
    "\n",
    "    loss, pred = model(**padded, return_dict = False)\n",
    "    real_Y.extend(pred.argmax(axis = 1).detach().cpu().numpy().tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0    0.88042   0.83329   0.85621      8146\n",
      "           1    0.83692   0.88316   0.85942      7891\n",
      "\n",
      "    accuracy                        0.85783     16037\n",
      "   macro avg    0.85867   0.85823   0.85781     16037\n",
      "weighted avg    0.85901   0.85783   0.85779     16037\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "print(\n",
    "    metrics.classification_report(\n",
    "        real_Y, test_Y,\n",
    "        digits = 5\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[ 3657,    20,    29,  1541,   863,   482,  6558,  1523, 11768,    15,\n",
       "          7705,  1816,    21,    15,   704,  5359,   436, 15246,  3534, 13870,\n",
       "         28305,   365, 15559,    17,  3657,    21,    29, 15577, 19512,   418,\n",
       "          5079,  1267, 13904,  6332,   331, 20385,  8406,  8948,    17,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0],\n",
       "        [ 3657,    20,    29,  8080, 16965,    17,  3657,    21,    29, 23562,\n",
       "           954,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0],\n",
       "        [ 3657,    20,    29,  2548, 10328,    15,   384,  6362,  5513,   385,\n",
       "           331,  3959, 25744,   385,   384,  8924,  8928,  2026,  8399,  3123,\n",
       "           390,  4589,  2852,  6651, 10283, 10832,   887,  5556,    15,  2882,\n",
       "           574,  1170, 20065,   788, 18278,   325,   331, 10688,    17,  3657,\n",
       "            21,    29,  9660, 10832,   887,  5556,  2824,  1267,   390,   619,\n",
       "         15756,   860,   384,  6362,  5513,   564,  8924,  8928,  3959, 25744,\n",
       "            17],\n",
       "        [ 3657,    20,    29,  4667, 12644,  1184,   331, 10280,   860,  6846,\n",
       "           299,   677,   331, 12301,  2824,   619,  6107,   309,  1003,   369,\n",
       "          8973,   327,   514,   390,   384, 11897,  1526, 14139,   390,   619,\n",
       "         14623,   325, 11769,  4182,    17,  3657,    21,    29,  7121,  7208,\n",
       "           515,  2257, 26831,  6622,   450,  5765,   403,    17,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0],\n",
       "        [ 3657,    20,    29, 15560,   677,  8399,   619,  9703,   385, 19240,\n",
       "           498,   435,  3657,    21,    29, 10056, 13042,  5541,   629,  1354,\n",
       "          8164,  2826,  9245,   355,  2824,  9734,    17,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0'), 'labels': tensor([1, 0, 1, 0, 0], device='cuda:0')}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "padded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.model_input_names = ['input_ids', 'attention_mask']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/mesolitica/finetune-mnli-malaysian-debertav2-base/commit/12af7951dbbadfc331b35009457e0c4aca40bb75', commit_message='Upload tokenizer', commit_description='', oid='12af7951dbbadfc331b35009457e0c4aca40bb75', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.push_to_hub('mesolitica/finetune-mnli-malaysian-debertav2-base', safe_serialization = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c89f5a83dcb340fab32455335d7b9bea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cac2fb15904c4c5b856d8a2975341080",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/443M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/mesolitica/finetune-mnli-malaysian-debertav2-base/commit/4378e43229b01a9bd8de16fa672b0b6a4fae2d42', commit_message='Upload DebertaV2ForSequenceClassification', commit_description='', oid='4378e43229b01a9bd8de16fa672b0b6a4fae2d42', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.push_to_hub('mesolitica/finetune-mnli-malaysian-debertav2-base', safe_serialization = True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
